#include "statsxx/machine_learning/DBN.hpp"

// STL
#include <string>  // std::string, std::to_string()
#include <utility> // std::make_pair()
#include <vector>  // std::vector<>

// jScience
#include "jScience/linalg.hpp" // Matrix<>


// using namespace machine_learning;


//
// DESC: New DBN
//
// NOTE: the training proceeds by the greedy training algorithm:
//
//     TODO: add ref
//
inline void machine_learning::DBN::train(
                                         const int                  nepoch,
                                         const int                  nbatches,
                                         // ---
                                         const int                  K_start,            // starting number of CD steps
                                         const int                  K_end,              // ending number
                                         const double               K_rate,             // rate at which K_start transitions to K_end
                                         // ---
                                         const bool                 sample_v,           // whether to sample v during reconstructions --- this is correct and may lead to better density models (Hinton, practical guide), but not reduces sampling noise thus allowing faster learning
                                         const bool                 sample_hdata,       // whether to sample h for the positive statistics --- true is closer to the mathematical model of an RBM
                                         // ---
                                         double                     lr,
                                         const std::vector<double> &lr_dot,
                                         const std::vector<double> &lr_adj,
                                         const double               lr_min,
                                         const double               lr_max,
                                         // ---
                                         const double               momentum_start,
                                         const double               momentum_end,
                                         const double               momentum_rate,
                                         const std::vector<double> &momentum_dot,
                                         const std::vector<double> &momentum_adj,
                                         // ---
                                         const double               weight_penalty,
                                         // ---
                                         const int                  free_energy_npts,   // number of training points to (randomly) select for free energy
                                         const int                  free_energy_nepoch, // evaluate the free energy every number of epochs
                                         const int                  free_energy_window, // calculate the slope of the free energy over this window
                                         const double               convg_criterion_max_inc_max_w,
                                         const int                  convg_criterion_nno_improvement,
                                         // ---
                                         Matrix<double>             X,                  // training set
                                         Matrix<double>             X_val,              // validation set
                                         // ---
                                         const std::string          prefix
                                         )
{
    for(auto i = 0; i < this->RBM.size(); ++i)
    {
        this->RBM[i].train(
                           nepoch,
                           nbatches,
                           // ---
                           K_start,
                           K_end,
                           K_rate,
                           // ---
                           sample_v,
                           sample_hdata,
                           // ---
                           lr,
                           lr_dot,
                           lr_adj,
                           lr_min,
                           lr_max,
                           // ---
                           momentum_start,
                           momentum_end,
                           momentum_rate,
                           momentum_dot,
                           momentum_adj,
                           // ---
                           weight_penalty,
                           // ---
                           free_energy_npts,
                           free_energy_nepoch,
                           free_energy_window,
                           convg_criterion_max_inc_max_w,
                           convg_criterion_nno_improvement,
                           // ---
                           X,
                           X_val,
                           // ---
                           (prefix + ".RBM." + std::to_string(i))
                           );

        // ----- GET NEW TRAINING SET -----
        if(i < (this->RBM.size()-1))
        {
            X = this->RBM[i].v_to_h(
                                    X
                                    );
            
            X_val = this->RBM[i].v_to_h(
                                        X_val
                                        );
        }
    }
}